{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#导入包\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from gensim.models import Word2Vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training and test sets from the local data directory\n",
    "train, test = (\n",
    "    pd.read_csv(csv_path)\n",
    "    for csv_path in ('./data/train_set.csv', './data/test_set.csv')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>article</th>\n",
       "      <th>word_seg</th>\n",
       "      <th>class</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>7368 1252069 365865 755561 1044285 129532 1053...</td>\n",
       "      <td>816903 597526 520477 1179558 1033823 758724 63...</td>\n",
       "      <td>14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>581131 165432 7368 957317 1197553 570900 33659...</td>\n",
       "      <td>90540 816903 441039 816903 569138 816903 10343...</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>7368 87936 40494 490286 856005 641588 145611 1...</td>\n",
       "      <td>816903 1012629 957974 1033823 328210 947200 65...</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>299237 760651 299237 887082 159592 556634 7489...</td>\n",
       "      <td>563568 1239563 680125 780219 782805 1033823 19...</td>\n",
       "      <td>13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>7368 7368 7368 865510 7368 396966 995243 37685...</td>\n",
       "      <td>816903 816903 816903 139132 816903 312320 1103...</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                                            article  \\\n",
       "0   0  7368 1252069 365865 755561 1044285 129532 1053...   \n",
       "1   1  581131 165432 7368 957317 1197553 570900 33659...   \n",
       "2   2  7368 87936 40494 490286 856005 641588 145611 1...   \n",
       "3   3  299237 760651 299237 887082 159592 556634 7489...   \n",
       "4   4  7368 7368 7368 865510 7368 396966 995243 37685...   \n",
       "\n",
       "                                            word_seg  class  \n",
       "0  816903 597526 520477 1179558 1033823 758724 63...     14  \n",
       "1  90540 816903 441039 816903 569138 816903 10343...      3  \n",
       "2  816903 1012629 957974 1033823 328210 947200 65...     12  \n",
       "3  563568 1239563 680125 780219 782805 1033823 19...     13  \n",
       "4  816903 816903 816903 139132 816903 312320 1103...     12  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first five training rows\n",
    "train.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hcq/miniconda3/envs/python3/lib/python3.6/site-packages/ipykernel_launcher.py:1: FutureWarning: Sorting because non-concatenation axis is not aligned. A future version\n",
      "of pandas will change to not sort by default.\n",
      "\n",
      "To accept the future behavior, pass 'sort=False'.\n",
      "\n",
      "To retain the current behavior and silence the warning, pass 'sort=True'.\n",
      "\n",
      "  \"\"\"Entry point for launching an IPython kernel.\n"
     ]
    }
   ],
   "source": [
    "# Stack train and test row-wise into a single corpus frame.\n",
    "# The two frames' columns are not aligned (see the pandas FutureWarning this\n",
    "# call used to emit), so pin `sort=True` explicitly: it keeps the sorted\n",
    "# column order regardless of pandas version and silences the warning.\n",
    "df = pd.concat([train, test], axis=0, sort=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>article</th>\n",
       "      <th>class</th>\n",
       "      <th>id</th>\n",
       "      <th>word_seg</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>7368 1252069 365865 755561 1044285 129532 1053...</td>\n",
       "      <td>14.0</td>\n",
       "      <td>0</td>\n",
       "      <td>816903 597526 520477 1179558 1033823 758724 63...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>581131 165432 7368 957317 1197553 570900 33659...</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1</td>\n",
       "      <td>90540 816903 441039 816903 569138 816903 10343...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>7368 87936 40494 490286 856005 641588 145611 1...</td>\n",
       "      <td>12.0</td>\n",
       "      <td>2</td>\n",
       "      <td>816903 1012629 957974 1033823 328210 947200 65...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>299237 760651 299237 887082 159592 556634 7489...</td>\n",
       "      <td>13.0</td>\n",
       "      <td>3</td>\n",
       "      <td>563568 1239563 680125 780219 782805 1033823 19...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>7368 7368 7368 865510 7368 396966 995243 37685...</td>\n",
       "      <td>12.0</td>\n",
       "      <td>4</td>\n",
       "      <td>816903 816903 816903 139132 816903 312320 1103...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             article  class  id  \\\n",
       "0  7368 1252069 365865 755561 1044285 129532 1053...   14.0   0   \n",
       "1  581131 165432 7368 957317 1197553 570900 33659...    3.0   1   \n",
       "2  7368 87936 40494 490286 856005 641588 145611 1...   12.0   2   \n",
       "3  299237 760651 299237 887082 159592 556634 7489...   13.0   3   \n",
       "4  7368 7368 7368 865510 7368 396966 995243 37685...   12.0   4   \n",
       "\n",
       "                                            word_seg  \n",
       "0  816903 597526 520477 1179558 1033823 758724 63...  \n",
       "1  90540 816903 441039 816903 569138 816903 10343...  \n",
       "2  816903 1012629 957974 1033823 328210 947200 65...  \n",
       "3  563568 1239563 680125 780219 782805 1033823 19...  \n",
       "4  816903 816903 816903 139132 816903 312320 1103...  "
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first five rows of the combined frame\n",
    "df.head(n=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练word2vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "[[\"欢迎\",\"来到\",\"深度\",\"之眼\"],[\"欢迎\",\"来到\",\"深度\",\"之眼\"],[\"欢迎\",\"来到\",\"深度\",\"之眼\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "816903 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenise every document: each `word_seg` string is split on single spaces\n",
    "# into a list of token ids, giving one token list per document.\n",
    "sentences = [doc_text.split(\" \") for doc_text in df['word_seg'].tolist()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['816903', '597526', '520477', '1179558', '1033823']"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the first five tokens of the first tokenised document\n",
    "sentences[0][0:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train word vectors on the tokenised corpus.\n",
    "# NOTE(review): `size` and `iter` are presumably the gensim 3.x keyword names\n",
    "# (the deprecation warnings in this notebook mention removal in 4.0.0).\n",
    "model = Word2Vec(\n",
    "    sentences=sentences,\n",
    "    # --- architecture ---\n",
    "    sg=0,              # CBOW\n",
    "    cbow_mean=1,       # sum the context vectors, then take the mean\n",
    "    size=200,          # embedding dimensionality\n",
    "    window=5,          # library default\n",
    "    min_count=2,       # author's note: try 2 or 3\n",
    "    # --- objective ---\n",
    "    hs=0,              # use negative sampling\n",
    "    negative=5,        # number of negative samples\n",
    "    ns_exponent=0.75,\n",
    "    sample=0.001,\n",
    "    # --- optimisation ---\n",
    "    alpha=0.025,       # library default\n",
    "    min_alpha=0.0001,\n",
    "    iter=10,           # author's note: 10 to 15\n",
    "    # --- runtime ---\n",
    "    seed=2018,\n",
    "    workers=11,        # worker threads\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## 保存word2vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(\"./word2vec/word2vec_word_200.model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Word2Vec.load(\"./word2vec/word2vec_word_200.model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(200,)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the shape of one token's embedding vector\n",
    "np.shape(model.wv['816903'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Help on method most_similar in module gensim.models.base_any2vec:\n",
      "\n",
      "most_similar(positive=None, negative=None, topn=10, restrict_vocab=None, indexer=None) method of gensim.models.word2vec.Word2Vec instance\n",
      "    Deprecated, use self.wv.most_similar() instead.\n",
      "    \n",
      "    Refer to the documentation for :meth:`~gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.most_similar`.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# `Word2Vec.most_similar` is deprecated in favour of the KeyedVectors API\n",
    "# (`model.wv`), so look up the documentation of the supported method.\n",
    "help(model.wv.most_similar)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hcq/miniconda3/envs/python3/lib/python3.6/site-packages/ipykernel_launcher.py:2: DeprecationWarning: Call to deprecated `most_similar` (Method will be removed in 4.0.0, use self.wv.most_similar() instead).\n",
      "  \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[('12875', 0.8677932620048523),\n",
       " ('679169', 0.8625671863555908),\n",
       " ('90540', 0.841310977935791),\n",
       " ('425105', 0.8043540716171265),\n",
       " ('866203', 0.7445841431617737),\n",
       " ('122513', 0.7241939902305603),\n",
       " ('1234861', 0.7100560069084167),\n",
       " ('85838', 0.7024739980697632),\n",
       " ('1189755', 0.6224364638328552),\n",
       " ('426716', 0.5778474807739258),\n",
       " ('816903', 0.5615671873092651),\n",
       " ('797828', 0.557973325252533),\n",
       " ('1254728', 0.5530299544334412),\n",
       " ('11177', 0.546566367149353),\n",
       " ('850976', 0.5452205538749695),\n",
       " ('48896', 0.5422906875610352),\n",
       " ('903604', 0.5324429273605347),\n",
       " ('1146147', 0.5293028354644775),\n",
       " ('1200328', 0.527854859828949),\n",
       " ('1104318', 0.5183314085006714)]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Top-20 tokens most similar to the query token.\n",
    "# Called through `model.wv`: the model-level `most_similar` is deprecated\n",
    "# (removed in gensim 4.0.0) and emits a DeprecationWarning.\n",
    "model.wv.most_similar(\"700659\", topn=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6617146"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Similarity score between two tokens\n",
    "model.wv.similarity(w1=\"816903\", w2=\"1226448\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the vocabulary: report its size and preview the first 20 tokens\n",
    "# instead of dumping every key into the cell output.\n",
    "len(model.wv.vocab), list(model.wv.vocab.keys())[:20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "#迭代模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenise the test documents (split on single spaces) to use as the\n",
    "# additional sentences for incremental training.\n",
    "sentences_next = [doc_text.split(\" \") for doc_text in test['word_seg'].tolist()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Continue training on the additional sentences.\n",
    "# 1. Update the vocabulary first so tokens unseen in the original corpus\n",
    "#    get vectors instead of being silently ignored.\n",
    "# 2. `total_examples` must describe the corpus actually passed to `train`\n",
    "#    (it drives progress and learning-rate decay), not the original corpus.\n",
    "# 3. Use `model.epochs`; `model.iter` is a deprecated alias.\n",
    "model.build_vocab(sentences_next, update=True)\n",
    "model.train(\n",
    "    sentences_next,\n",
    "    total_examples=len(sentences_next),\n",
    "    epochs=model.epochs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `save` requires a target path (calling it with no argument raises TypeError).\n",
    "# Persist the updated model to the same path used by the earlier save/load\n",
    "# cells; note that this overwrites the previously saved file.\n",
    "model.save(\"./word2vec/word2vec_word_200.model\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
